package com.example.demo.redis;

import lombok.extern.slf4j.Slf4j;
import org.aspectj.lang.JoinPoint;
import org.aspectj.lang.annotation.*;
import org.aspectj.lang.reflect.MethodSignature;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.core.DefaultParameterNameDiscoverer;
import org.springframework.expression.EvaluationContext;
import org.springframework.expression.Expression;
import org.springframework.expression.spel.standard.SpelExpressionParser;
import org.springframework.expression.spel.support.StandardEvaluationContext;
import org.springframework.stereotype.Component;

import java.lang.reflect.Method;
import java.util.UUID;

/**
 * Redis distributed-lock aspect.
 *
 * <p>Intercepts methods annotated with {@link RedisLock}. The outermost annotated call on a
 * thread acquires the lock via {@code redisLockService.lock(key, value)} before the method runs.
 * It releases the lock via {@code redisLockService.unlock(key, value)} after the method returns
 * or throws. Nested annotated calls on the same thread do not take another lock. They only
 * increase a per-thread nesting depth, so only the outermost annotation is effective.
 *
 * <p>Per-thread state (key, value, nesting depth) is kept in {@link ThreadLocal}s. It is cleared
 * when the outermost call finishes and when lock acquisition fails, so reused pool threads do not
 * carry stale state into later calls.
 */
@Aspect
@Component
@Slf4j
public class RedisLockAspect {

    /** Parser for the SpEL expression in the annotation key; reusable, so shared by all calls. */
    private static final SpelExpressionParser SPEL_PARSER = new SpelExpressionParser();

    /** Resolves the intercepted method's parameter names so they can be used as SpEL variables. */
    private static final DefaultParameterNameDiscoverer PARAMETER_NAME_DISCOVERER =
            new DefaultParameterNameDiscoverer();

    /**
     * Per-thread lock key held by the outermost annotated call.
     */
    private final ThreadLocal<String> threadLocalKey = new ThreadLocal<>();

    /**
     * Per-thread lock value (owner token) passed to both lock and unlock.
     */
    private final ThreadLocal<String> threadLocalValue = new ThreadLocal<>();

    /**
     * Per-thread nesting depth of annotated calls. Only the outermost call (depth 0 to 1)
     * acquires the lock, and only its exit (depth 1 to 0) releases it.
     */
    private final ThreadLocal<Integer> threadLocalCount = new ThreadLocal<>();

    @Autowired
    private RedisLockService redisLockService;

    /**
     * Matches every method annotated with {@link RedisLock}.
     */
    @Pointcut("@annotation(com.example.demo.redis.RedisLock)")
    public void pointCut() {

    }

    /**
     * Acquires the distributed lock before the outermost annotated method on this thread runs.
     *
     * <p>Nested annotated calls only increase the nesting depth. Per-thread state is recorded
     * only after the lock has actually been acquired. A failed attempt therefore leaves nothing
     * behind that could make a later call on the same (pooled) thread look nested and skip
     * locking.
     *
     * @param joinPoint the intercepted method call, used to resolve the lock key
     * @param redisLock the annotation on the intercepted method
     * @throws RedisLockException if the lock could not be acquired; carries {@code redisLock.msg()}
     */
    @Before("pointCut()&&@annotation(redisLock)")
    public void beforeHandle(JoinPoint joinPoint, RedisLock redisLock) {
        if (!isFirst()) {
            // Nested annotated call: the outermost call on this thread already holds the lock.
            threadLocalCount.set(threadLocalCount.get() + 1);
            return;
        }
        String key = this.getRedisKey(joinPoint, redisLock);
        String value = Thread.currentThread().getName() + "-" + UUID.randomUUID();
        // NOTE(review): redisLock.timeout() is not forwarded to the lock service; the
        // two-argument lock(key, value) is used. Confirm this is intended.
        boolean locked = redisLockService.lock(key, value);
        if (!locked) {
            // Acquisition failed: make sure no per-thread state survives this call.
            removeThreadLocal();
            throw new RedisLockException(redisLock.msg());
        }
        threadLocalKey.set(key);
        threadLocalValue.set(value);
        threadLocalCount.set(1);
    }

    /**
     * Runs after an annotated method returns normally; see {@link #unlock()}.
     *
     * @param joinPoint the intercepted method call (unused)
     */
    @AfterReturning("pointCut()")
    public void afterHandle(JoinPoint joinPoint) {
        unlock();
    }

    /**
     * Runs after an annotated method throws; see {@link #unlock()}.
     *
     * @param joinPoint the intercepted method call (unused)
     * @param e         the exception thrown by the method (unused)
     */
    @AfterThrowing(pointcut = "pointCut()", throwing = "e")
    public void afterThrowable(JoinPoint joinPoint, Throwable e) {
        unlock();
    }

    /**
     * Leaves one level of annotated-call nesting.
     *
     * <p>When the outermost call exits, this releases the lock and clears all per-thread state.
     * The state is cleared even if the release throws.
     */
    private void unlock() {
        Integer depth = threadLocalCount.get();
        if (depth == null || depth <= 0) {
            // No lock is tracked for this thread (e.g. acquisition failed); nothing to release.
            removeThreadLocal();
            return;
        }
        if (depth > 1) {
            // Leaving a nested annotated call: the outermost call still holds the lock.
            threadLocalCount.set(depth - 1);
            return;
        }
        try {
            redisLockService.unlock(threadLocalKey.get(), threadLocalValue.get());
        } finally {
            // Pooled threads are reused, so stale state must never leak into the next call.
            removeThreadLocal();
        }
    }

    /**
     * Clears all per-thread lock state.
     *
     * <p>Pooled threads are reused, so the state must be removed to avoid leaking it into
     * whatever runs next on the same thread.
     */
    private void removeThreadLocal() {
        threadLocalCount.remove();
        threadLocalKey.remove();
        threadLocalValue.remove();
    }

    /**
     * Tells whether no annotated call is currently active on this thread.
     *
     * @return {@code true} if the nesting depth is unset or zero, i.e. the current call is the
     *         outermost annotated call
     */
    private boolean isFirst() {
        Integer depth = threadLocalCount.get();
        return depth == null || depth <= 0;
    }

    /**
     * Resolves the Redis key for the lock.
     *
     * <p>The annotation's {@code key()} is evaluated as a SpEL expression. The intercepted
     * method's arguments are bound as variables named after its parameters (e.g.
     * {@code #orderId}). The literal annotation key is returned unchanged in these cases:
     * <ul>
     *   <li>the method has no (discoverable) parameter names;</li>
     *   <li>the expression cannot be parsed or evaluated;</li>
     *   <li>the expression evaluates to {@code null} or an empty string.</li>
     * </ul>
     *
     * @param joinPoint the intercepted method call
     * @param redisLock the annotation carrying the key expression
     * @return the evaluated key, or the literal annotation key as a fallback
     */
    private String getRedisKey(JoinPoint joinPoint, RedisLock redisLock) {
        String key = redisLock.key();

        MethodSignature signature = (MethodSignature) joinPoint.getSignature();
        Method method = signature.getMethod();
        String[] parameterNames = PARAMETER_NAME_DISCOVERER.getParameterNames(method);
        if (parameterNames == null || parameterNames.length == 0) {
            // No parameters to bind: use the annotation key as-is.
            return key;
        }
        Object[] parameterValues = joinPoint.getArgs();
        try {
            // Parsing is inside the try so that a literal key which is not valid SpEL
            // falls back to the annotation key instead of failing the call.
            Expression expression = SPEL_PARSER.parseExpression(key);
            EvaluationContext evaluationContext = new StandardEvaluationContext();
            for (int i = 0; i < parameterNames.length; i++) {
                evaluationContext.setVariable(parameterNames[i], parameterValues[i]);
            }
            Object expressionValue = expression.getValue(evaluationContext);
            if (expressionValue != null && !"".equals(expressionValue.toString())) {
                return expressionValue.toString();
            }
            return key;
        } catch (Exception ignored) {
            // Deliberate best-effort: a key that is not a usable SpEL expression is a literal key.
            return key;
        }
    }

}

